{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install pytorch_pretrained_bert pytorch-nlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import numpy as np\n",
    "import random as rn\n",
    "import torch\n",
    "from pytorch_pretrained_bert import BertModel\n",
    "from torch import nn\n",
    "from pytorch_pretrained_bert import BertTokenizer\n",
    "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
    "from torch.optim import Adam\n",
    "from torch.nn.utils import clip_grad_norm_\n",
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "rn.seed(10)\n",
    "np.random.seed(10)\n",
    "torch.manual_seed(10)\n",
    "torch.cuda.manual_seed(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# News Aggregator Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/taeyong/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n",
      "/home/taeyong/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import keras\n",
    "import tensorflow as tf\n",
    "\n",
    "# keras: for data processing\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.utils.np_utils import to_categorical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data load\n",
    "data = pd.read_csv('./data/uci_news_aggregator.csv', delimiter = ',', skiprows = 1,\n",
    "                   names = ['ID', 'TITLE', 'URL', 'PUBLISHER', 'CATEGORY', 'STORY',\n",
    "                            'HOSTNAME', 'TIMESTAMP'], usecols=['TITLE', 'CATEGORY'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = data.reindex(np.random.permutation(data.index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(422419, 2)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>TITLE</th>\n",
       "      <th>CATEGORY</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>207842</td>\n",
       "      <td>Dow and S&amp;P 500 index close at record levels</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>223144</td>\n",
       "      <td>U.S. Consumer Price Inflation Accelerates In A...</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>329274</td>\n",
       "      <td>Broadway Will Dim Its Marquee Lights for Mary ...</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90363</td>\n",
       "      <td>No real change in childhood obesity rates</td>\n",
       "      <td>m</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>215176</td>\n",
       "      <td>Newly-unveiled footage from 1937 shows FDR wal...</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30170</td>\n",
       "      <td>Some missing bitcoins are found</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>245057</td>\n",
       "      <td>Home / Niederauer to resign as NYSE CEO</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60944</td>\n",
       "      <td>GM ignition switch recall expands</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>249542</td>\n",
       "      <td>Piketty: FT's criticism 'ridiculous'</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15711</td>\n",
       "      <td>German deputy finance minister greets court ru...</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>372487</td>\n",
       "      <td>'Extant' review: The final frontier of been th...</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>546</td>\n",
       "      <td>Sbarro goes bankrupt for second time</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>178393</td>\n",
       "      <td>Michelle Obama Pays Tribute To Anna Wintour At...</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>295723</td>\n",
       "      <td>Los Angeles Mayor Drops F-Bomb on TV -- Here's...</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>66916</td>\n",
       "      <td>Climate Change Will Destabilise the World, War...</td>\n",
       "      <td>t</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>147395</td>\n",
       "      <td>Lindsay Lohan talks about 'sex list' on 'WWHL'...</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>181924</td>\n",
       "      <td>West Michigan celebrates Cinco de Mayo</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>63406</td>\n",
       "      <td>Lady Gaga Talks About Her Wild, Violent Nights...</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22673</td>\n",
       "      <td>A friend of L'Wren Scott says the late fashion...</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>235624</td>\n",
       "      <td>Will and Jada Smith Investigated by Children S...</td>\n",
       "      <td>e</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    TITLE CATEGORY\n",
       "207842       Dow and S&P 500 index close at record levels        b\n",
       "223144  U.S. Consumer Price Inflation Accelerates In A...        b\n",
       "329274  Broadway Will Dim Its Marquee Lights for Mary ...        e\n",
       "90363           No real change in childhood obesity rates        m\n",
       "215176  Newly-unveiled footage from 1937 shows FDR wal...        e\n",
       "30170                     Some missing bitcoins are found        b\n",
       "245057            Home / Niederauer to resign as NYSE CEO        b\n",
       "60944                   GM ignition switch recall expands        b\n",
       "249542               Piketty: FT's criticism 'ridiculous'        b\n",
       "15711   German deputy finance minister greets court ru...        b\n",
       "372487  'Extant' review: The final frontier of been th...        e\n",
       "546                  Sbarro goes bankrupt for second time        b\n",
       "178393  Michelle Obama Pays Tribute To Anna Wintour At...        e\n",
       "295723  Los Angeles Mayor Drops F-Bomb on TV -- Here's...        e\n",
       "66916   Climate Change Will Destabilise the World, War...        t\n",
       "147395  Lindsay Lohan talks about 'sex list' on 'WWHL'...        e\n",
       "181924             West Michigan celebrates Cinco de Mayo        e\n",
       "63406   Lady Gaga Talks About Her Wild, Violent Nights...        e\n",
       "22673   A friend of L'Wren Scott says the late fashion...        e\n",
       "235624  Will and Jada Smith Investigated by Children S...        e"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data slicing\n",
    "num_of_categories = 12500\n",
    "shuffled = data.reindex(np.random.permutation(data.index))\n",
    "\n",
    "e = shuffled[shuffled['CATEGORY'] == 'e']#[:num_of_categories]\n",
    "b = shuffled[shuffled['CATEGORY'] == 'b']#[:num_of_categories]\n",
    "t = shuffled[shuffled['CATEGORY'] == 't']#[:num_of_categories]\n",
    "m = shuffled[shuffled['CATEGORY'] == 'm']#[:num_of_categories]\n",
    "\n",
    "concated = pd.concat([e,b,t,m], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# label col\n",
    "concated['LABEL'] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(41)\n",
    "concated = concated.reindex(np.random.permutation(concated.index))\n",
    "\n",
    "concated.loc[concated['CATEGORY'] == 'e', 'LABEL'] = 0\n",
    "concated.loc[concated['CATEGORY'] == 'b', 'LABEL'] = 1\n",
    "concated.loc[concated['CATEGORY'] == 't', 'LABEL'] = 2\n",
    "concated.loc[concated['CATEGORY'] == 'm', 'LABEL'] = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "51683     0\n",
       "364094    2\n",
       "220239    1\n",
       "15418     0\n",
       "162428    1\n",
       "346459    2\n",
       "60046     0\n",
       "331513    2\n",
       "7953      0\n",
       "78829     0\n",
       "Name: LABEL, dtype: int64"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concated['LABEL'][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one-hot encoding\n",
    "labels = to_categorical(concated['LABEL'], num_classes=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 0., 0., 0.],\n",
       "       [0., 0., 1., 0.],\n",
       "       [0., 1., 0., 0.],\n",
       "       ...,\n",
       "       [0., 0., 1., 0.],\n",
       "       [1., 0., 0., 0.],\n",
       "       [0., 0., 1., 0.]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_max_len = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([\"Georgina Haig Joins 'Once Upon A Time' as Queen Elsa\",\n",
       "       'Stolen social insurance numbers can cause many problems',\n",
       "       \"Will Twitter's Q1 earnings meet Wall Street expectations?\"],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    " concated['TITLE'].values[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens = list(map(lambda t: ['[CLS]'] + tokenizer.tokenize(t) + ['[SEP]'], concated['TITLE'].values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['[CLS]',\n",
       "  'georgina',\n",
       "  'hai',\n",
       "  '##g',\n",
       "  'joins',\n",
       "  \"'\",\n",
       "  'once',\n",
       "  'upon',\n",
       "  'a',\n",
       "  'time',\n",
       "  \"'\",\n",
       "  'as',\n",
       "  'queen',\n",
       "  'elsa',\n",
       "  '[SEP]'],\n",
       " ['[CLS]',\n",
       "  'stolen',\n",
       "  'social',\n",
       "  'insurance',\n",
       "  'numbers',\n",
       "  'can',\n",
       "  'cause',\n",
       "  'many',\n",
       "  'problems',\n",
       "  '[SEP]'],\n",
       " ['[CLS]',\n",
       "  'will',\n",
       "  'twitter',\n",
       "  \"'\",\n",
       "  's',\n",
       "  'q',\n",
       "  '##1',\n",
       "  'earnings',\n",
       "  'meet',\n",
       "  'wall',\n",
       "  'street',\n",
       "  'expectations',\n",
       "  '?',\n",
       "  '[SEP]']]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (11776 > 512). Running this sequence through BERT will result in indexing errors\n",
      "Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (1404 > 512). Running this sequence through BERT will result in indexing errors\n",
      "Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (9657 > 512). Running this sequence through BERT will result in indexing errors\n",
      "Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (7454 > 512). Running this sequence through BERT will result in indexing errors\n",
      "Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (1069 > 512). Running this sequence through BERT will result in indexing errors\n",
      "Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (3836 > 512). Running this sequence through BERT will result in indexing errors\n",
      "Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (9609 > 512). Running this sequence through BERT will result in indexing errors\n",
      "Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (5445 > 512). Running this sequence through BERT will result in indexing errors\n"
     ]
    }
   ],
   "source": [
    "tokens_ids = pad_sequences(list(map(tokenizer.convert_tokens_to_ids, tokens)), maxlen=word_max_len, truncating=\"post\", padding=\"post\", dtype=\"int\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(422419, 64)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  101, 27358, 15030,  2290,  9794,  1005,  2320,  2588,  1037,\n",
       "         2051,  1005,  2004,  3035, 23452,   102,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "            0]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens_ids[:1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks = [[float(i > 0) for i in ii] for ii in tokens_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  1.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0]]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masks[:1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT Baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertClassifier(nn.Module):\n",
    "    def __init__(self, dropout=0.1):\n",
    "        super(BertClassifier, self).__init__()\n",
    "\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased')\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.linear = nn.Linear(768, 4)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "        self.softmax = nn.Softmax()\n",
    "    \n",
    "    def forward(self, tokens, masks=None):\n",
    "        _, pooled_output = self.bert(tokens, attention_mask=masks, output_all_encoded_layers=False)\n",
    "        dropout_output = self.dropout(pooled_output)\n",
    "        linear_output = self.linear(dropout_output)\n",
    "#         proba = self.sigmoid(linear_output)\n",
    "        proba = self.softmax(linear_output)\n",
    "        return proba"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.0M'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_clf = BertClassifier()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_clf = bert_clf.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'439.074304M'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor(tokens_ids[:3]).to(device)\n",
    "y, pooled = bert_clf.bert(x, output_all_encoded_layers=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([3, 64]), torch.Size([3, 64, 768]), torch.Size([3, 768]))"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape, y.shape, pooled.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  101, 27358, 15030,  2290,  9794,  1005,  2320,  2588,  1037,  2051,\n",
       "          1005,  2004,  3035, 23452,   102,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0],\n",
       "        [  101,  7376,  2591,  5427,  3616,  2064,  3426,  2116,  3471,   102,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0],\n",
       "        [  101,  2097, 10474,  1005,  1055,  1053,  2487, 16565,  3113,  2813,\n",
       "          2395, 10908,  1029,   102,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0]], device='cuda:0')"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-2.8207e-01,  2.4838e-01,  5.8646e-01,  ..., -3.5017e-01,\n",
       "           1.0009e-01, -2.6468e-01],\n",
       "         [-4.0030e-02,  4.1529e-01,  5.6553e-01,  ..., -3.4315e-01,\n",
       "           3.8762e-01, -5.5785e-01],\n",
       "         [ 6.3444e-01,  4.2569e-01,  9.9663e-01,  ...,  2.9993e-01,\n",
       "           7.6073e-01,  8.5302e-01],\n",
       "         ...,\n",
       "         [-2.2172e-02, -1.5674e-02,  6.6832e-01,  ..., -4.5963e-01,\n",
       "           9.9788e-02, -7.3225e-01],\n",
       "         [ 2.8428e-03, -5.7244e-02,  8.5414e-01,  ..., -5.1558e-01,\n",
       "           9.4784e-02, -4.8146e-01],\n",
       "         [ 8.3657e-03, -2.8630e-01,  8.9913e-01,  ..., -6.5501e-01,\n",
       "           1.1099e-01, -2.4487e-01]],\n",
       "\n",
       "        [[-2.0049e-01,  2.1620e-01,  7.0347e-01,  ..., -3.1311e-01,\n",
       "           5.7537e-01, -2.0779e-01],\n",
       "         [ 5.1083e-01, -5.9955e-02,  8.7744e-01,  ..., -2.5219e-01,\n",
       "          -2.4464e-01, -2.3501e-01],\n",
       "         [ 2.9611e-01, -2.2467e-01,  7.1730e-01,  ..., -4.8486e-01,\n",
       "          -1.3786e-01, -6.7392e-01],\n",
       "         ...,\n",
       "         [ 5.6809e-02, -4.1697e-01,  9.6858e-01,  ..., -4.7042e-01,\n",
       "           2.4617e-01, -6.5506e-01],\n",
       "         [-1.4489e-02, -1.4797e-01,  6.3873e-01,  ..., -3.5994e-01,\n",
       "           2.3615e-01, -6.4712e-01],\n",
       "         [-3.8529e-01,  1.8723e-03,  8.1516e-01,  ..., -2.1629e-01,\n",
       "           2.4270e-01, -4.4588e-01]],\n",
       "\n",
       "        [[-2.0840e-01, -6.8944e-02,  5.7200e-01,  ..., -2.9226e-01,\n",
       "           4.4367e-01, -2.9175e-01],\n",
       "         [ 1.7605e-01, -1.5812e-01,  4.5609e-01,  ...,  3.5311e-01,\n",
       "           6.2837e-01, -1.0137e+00],\n",
       "         [ 3.0724e-01, -7.0891e-02,  3.3793e-01,  ..., -4.8285e-01,\n",
       "           2.3006e-01, -8.3535e-01],\n",
       "         ...,\n",
       "         [ 1.2746e-01, -3.3312e-01,  1.0224e+00,  ..., -4.5262e-01,\n",
       "           2.8309e-01, -9.1218e-01],\n",
       "         [ 6.9124e-01,  1.1194e-04,  9.9743e-01,  ..., -6.4581e-01,\n",
       "           3.1526e-01, -1.3636e+00],\n",
       "         [ 1.8478e-01, -1.0766e-01,  7.6723e-01,  ..., -4.7125e-01,\n",
       "           2.1066e-01, -9.0448e-01]]], device='cuda:0', grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5156, -0.3123, -0.6917,  ..., -0.0835, -0.5205,  0.6408],\n",
       "        [-0.4346, -0.3995, -0.7330,  ..., -0.1952, -0.5318,  0.5257],\n",
       "        [-0.4422, -0.2517, -0.6401,  ..., -0.1717, -0.4710,  0.4387]],\n",
       "       device='cuda:0', grad_fn=<TanhBackward>)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pooled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/taeyong/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:17: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    }
   ],
   "source": [
    "y = bert_clf(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.31779057, 0.16687712, 0.16944495, 0.34588736],\n",
       "       [0.3085278 , 0.24000305, 0.20457213, 0.24689703],\n",
       "       [0.30087715, 0.2129052 , 0.17448093, 0.3117367 ]], dtype=float32)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.cpu().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'894.48704M'"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "y, x, pooled = None, None, None\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'667.20768M'"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune BERT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Train / Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 3\n",
    "EPOCHS = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "split_size = 4 * int(len(tokens_ids) / 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_ids, X_test_ids = tokens_ids[:split_size,:], tokens_ids[split_size:,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(422419, 64)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(337932, 64)"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(84487, 64)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train, y_test = labels[:split_size,:], labels[split_size:,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(422419, 4)"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(337932, 4)"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(84487, 4)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks_train, masks_test = np.array(masks)[:split_size,:], np.array(masks)[split_size:,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(422419, 64)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(masks).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(337932, 64)"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masks_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(84487, 64)"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masks_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_tokens_tensor = torch.tensor(X_train_ids)\n",
    "train_y_tensor = torch.tensor(y_train).float()\n",
    "\n",
    "test_tokens_tensor = torch.tensor(X_test_ids)\n",
    "test_y_tensor = torch.tensor(y_test).float()\n",
    "\n",
    "train_masks_tensor = torch.tensor(masks_train)\n",
    "test_masks_tensor = torch.tensor(masks_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'667.20768M'"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = TensorDataset(train_tokens_tensor, train_masks_tensor, train_y_tensor)\n",
    "train_sampler = RandomSampler(train_dataset)\n",
    "train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE)\n",
    "\n",
    "test_dataset = TensorDataset(test_tokens_tensor, test_masks_tensor, test_y_tensor)\n",
    "test_sampler = SequentialSampler(test_dataset)\n",
    "test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "112644"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "28163"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(test_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[  101, 27358, 15030,  ...,     0,     0,     0],\n",
       "         [  101,  7376,  2591,  ...,     0,     0,     0],\n",
       "         [  101,  2097, 10474,  ...,     0,     0,     0],\n",
       "         ...,\n",
       "         [  101,  6106,  2015,  ...,     0,     0,     0],\n",
       "         [  101, 13753,  3976,  ...,     0,     0,     0],\n",
       "         [  101,  8861,  2402,  ...,     0,     0,     0]]),\n",
       " tensor([[1., 1., 1.,  ..., 0., 0., 0.],\n",
       "         [1., 1., 1.,  ..., 0., 0., 0.],\n",
       "         [1., 1., 1.,  ..., 0., 0., 0.],\n",
       "         ...,\n",
       "         [1., 1., 1.,  ..., 0., 0., 0.],\n",
       "         [1., 1., 1.,  ..., 0., 0., 0.],\n",
       "         [1., 1., 1.,  ..., 0., 0., 0.]], dtype=torch.float64),\n",
       " tensor([[1., 0., 0., 0.],\n",
       "         [0., 0., 1., 0.],\n",
       "         [0., 1., 0., 0.],\n",
       "         ...,\n",
       "         [1., 0., 0., 0.],\n",
       "         [0., 1., 0., 0.],\n",
       "         [1., 0., 0., 0.]]))"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.tensors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Fine-tune BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = Adam(bert_clf.parameters(), lr=3e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  3\n",
      "\r",
      "112643/112644 loss: 0.06198843557304419 \n"
     ]
    }
   ],
   "source": [
     "for epoch_num in range(EPOCHS):\n",
     "    bert_clf.train()\n",
     "    train_loss = 0\n",
     "    loss_func = nn.BCELoss()  # construct once, not on every step\n",
     "    for step_num, batch_data in enumerate(train_dataloader):\n",
     "        \n",
     "        # use a loop-local name so the module-level `labels` array is not clobbered\n",
     "        token_ids, masks, batch_labels = tuple(t.to(device) for t in batch_data)\n",
     "        logits = bert_clf(token_ids, masks)\n",
     "\n",
     "        batch_loss = loss_func(logits, batch_labels)\n",
     "        train_loss += batch_loss.item()\n",
     "        \n",
     "        optimizer.zero_grad()\n",
     "        batch_loss.backward()\n",
     "\n",
     "        clip_grad_norm_(parameters=bert_clf.parameters(), max_norm=1.0)\n",
     "        optimizer.step()\n",
     "        \n",
     "        clear_output(wait=True)\n",
     "        print('Epoch: ', epoch_num + 1)\n",
     "        print(\"\\r\" + \"{0}/{1} loss: {2} \".format(step_num + 1, len(train_dataloader), train_loss / (step_num + 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/taeyong/anaconda3/lib/python3.6/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type BertClassifier. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    }
   ],
   "source": [
    "torch.save(bert_clf, './bert_clf_' + str(EPOCHS) + 'epoch' + '.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "28162/28162 loss: 0.24793613381706459 \n"
     ]
    }
   ],
   "source": [
     "bert_clf.eval()\n",
     "bert_predicted = []\n",
     "all_logits = []\n",
     "eval_loss = 0  # accumulate the test loss; previously the stale train_loss was printed here\n",
     "loss_func = nn.BCELoss()\n",
     "with torch.no_grad():\n",
     "    for step_num, batch_data in enumerate(test_dataloader):\n",
     "\n",
     "        token_ids, masks, batch_labels = tuple(t.to(device) for t in batch_data)\n",
     "        logits = bert_clf(token_ids, masks)\n",
     "        \n",
     "        loss = loss_func(logits, batch_labels)\n",
     "        eval_loss += loss.item()\n",
     "        numpy_logits = logits.cpu().detach().numpy()\n",
     "        \n",
     "        for i in range(len(logits)):\n",
     "            bert_predicted.append(logits[i].argmax())\n",
     "#         bert_predicted += list(numpy_logits[:, 0] > 0.5)\n",
     "        all_logits += list(numpy_logits[:, 0])\n",
     "        \n",
     "        clear_output(wait=True)\n",
     "        print(\"\\r\" + \"{0}/{1} loss: {2} \".format(step_num + 1, len(test_dataloader), eval_loss / (step_num + 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_predicted = np.array(bert_predicted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "84487"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(bert_predicted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([tensor(0, device='cuda:0'), tensor(0, device='cuda:0'),\n",
       "       tensor(2, device='cuda:0'), ..., tensor(2, device='cuda:0'),\n",
       "       tensor(0, device='cuda:0'), tensor(2, device='cuda:0')],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_predicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 0., 0., 0.],\n",
       "       [1., 0., 0., 0.],\n",
       "       [0., 0., 1., 0.],\n",
       "       ...,\n",
       "       [0., 0., 1., 0.],\n",
       "       [1., 0., 0., 0.],\n",
       "       [0., 0., 1., 0.]])"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Accuracy of classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "correct_count = 0\n",
    "for i in range(len(bert_predicted)):\n",
    "    y = y_test[i].argmax()\n",
    "    \n",
    "    if bert_predicted[i].item() == y:\n",
    "        correct_count = correct_count + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.962325564879804"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "correct_count / len(bert_predicted)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Confusion matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.utils.multiclass import unique_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_confusion_matrix(y_true, y_pred, classes,\n",
    "                          normalize=False,\n",
    "                          title=None,\n",
    "                          cmap=plt.cm.Blues):\n",
    "    \"\"\"\n",
    "    This function prints and plots the confusion matrix.\n",
    "    Normalization can be applied by setting `normalize=True`.\n",
    "    \"\"\"\n",
    "    if not title:\n",
    "        if normalize:\n",
    "            title = 'Normalized confusion matrix'\n",
    "        else:\n",
    "            title = 'Confusion matrix, without normalization'\n",
    "\n",
    "    # Compute confusion matrix\n",
    "    cm = confusion_matrix(y_true, y_pred)\n",
    "    # Only use the labels that appear in the data\n",
    "    classes = classes[unique_labels(y_true, y_pred)]\n",
    "    if normalize:\n",
    "        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
    "        print(\"Normalized confusion matrix\")\n",
    "    else:\n",
    "        print('Confusion matrix, without normalization')\n",
    "\n",
    "    print(cm)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "    ax.figure.colorbar(im, ax=ax)\n",
    "    # We want to show all ticks...\n",
    "    ax.set(xticks=np.arange(cm.shape[1]),\n",
    "           yticks=np.arange(cm.shape[0]),\n",
    "           # ... and label them with the respective list entries\n",
    "           xticklabels=classes, yticklabels=classes,\n",
    "           title=title,\n",
    "           ylabel='True label',\n",
    "           xlabel='Predicted label')\n",
    "\n",
    "    # Rotate the tick labels and set their alignment.\n",
    "    plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\",\n",
    "             rotation_mode=\"anchor\")\n",
    "\n",
    "    # Loop over data dimensions and create text annotations.\n",
    "    fmt = '.2f' if normalize else 'd'\n",
    "    thresh = cm.max() / 2.\n",
    "    for i in range(cm.shape[0]):\n",
    "        for j in range(cm.shape[1]):\n",
    "            ax.text(j, i, format(cm[i, j], fmt),\n",
    "                    ha=\"center\", va=\"center\",\n",
    "                    color=\"white\" if cm[i, j] > thresh else \"black\")\n",
    "    fig.tight_layout()\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_names = np.array(['B', 'E', 'T', 'M'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "confusion_matrix_predicted = [element.item() for element in bert_predicted.flatten()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 3,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 3,\n",
       " 3,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 3,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 3,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 3,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 3,\n",
       " 3,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 3,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 3,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 3,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 1,\n",
       " 3,\n",
       " 1,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 3,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 3,\n",
       " 2,\n",
       " 2,\n",
       " 3,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 3,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 2,\n",
       " 1,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 3,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 3,\n",
       " 1,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 3,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 3,\n",
       " 1,\n",
       " 3,\n",
       " 2,\n",
       " 2,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 3,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 2,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " 2,\n",
       " 1,\n",
       " 0,\n",
       " 1,\n",
       " 0,\n",
       " ...]"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confusion_matrix_predicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 2, ..., 2, 0, 2])"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test.argmax(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Normalized confusion matrix\n",
      "[[0.98230992 0.00882857 0.0065226  0.00233891]\n",
      " [0.00698282 0.94709335 0.03919805 0.00672579]\n",
      " [0.00748544 0.03479346 0.9549025  0.00281859]\n",
      " [0.01202712 0.02667833 0.00885633 0.95243822]]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAToAAAEYCAYAAADMJjphAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8VOX1+PHPSUaIIEsiUCFB2UlIChjCoqiAdWEJ4AKKyirVb1VQsFrXIqV1BRVF/VlbN8AKoiKEIKBWrFARAoqyqIAJmoQKBAtqNZjh/P64NzEbmYHMZCaT8/Z1X+bOfea558m9nDx3e66oKsYYE8miQh2AMcYEmyU6Y0zEs0RnjIl4luiMMRHPEp0xJuJZojPGRDxLdBFKRE4UkQwROSgii6pRz1UisiqQsYWCiLwpIuNCHYcJDUt0ISYiV4pIloh8LyJ73H+QZwWg6hHAr4CTVXXk8Vaiqi+p6gUBiKcMEekvIioir5f7vJv7+Wo/65kuIvN9lVPVQar64nGGa2o5S3QhJCI3A7OB+3CS0qnAU8DwAFR/GvCFqhYFoK5g2QecKSInl/psHPBFoFYgDtvP6zpVtSkEE9AE+B4YWUWZ+jiJMN+dZgP13WX9gVzg98BeYA8wwV32J+Aw8LO7jonAdGB+qbrbAAp43PnxwJfAd0A2cFWpz9eU+t6ZwAbgoPv/M0stWw38GVjr1rMKaHaUthXH/zRwg/tZtPvZNGB1qbKPAV8Dh4CNwNnu5wPLtXNzqTjudeP4EejgfvZbd/n/A14tVf+DwDuAhHq/sCk4k/2lC50zgBhgcRVl7gL6AN2BbkAv4O5Sy0/BSZjxOMnsSRGJVdV7cHqJC1X1JFV9tqpARKQh8DgwSFUb4SSzjyspFwdkumVPBh4BMsv1yK4EJgAtgHrALVWtG5gLjHV/vhDYipPUS9uA8zuIA/4BLBKRGFVdUa6d3Up9ZwxwLdAI2F2uvt8DXUVkvIicjfO7G6du1jORxxJd6JwM7NeqDy2vAmao6l5V3YfTUxtTavnP7vKfVXU5Tq+m83HGcwRIEZETVXWPqm6tpMwQYIeqzlPVIlV9GfgMGFqqzPOq+oWq/gi8gpOgjkpV/w3EiUhnnIQ3t5Iy81W1wF3nwzg9XV/tfEFVt7rf+blcff8DRuMk6vnAZFXN9VGfqcUs0YVOAdBMRDxVlGlF2d7IbvezkjrKJcr/AScdayCq+gNwOfA7YI+IZIpIoh/xFMcUX2r+P8cRzzxgEjCASnq4IvJ7EdnuXkH+L04vtpmPOr+uaqGqrsc5VBechGwimCW60PkA+Am4qIoy+TgXFYqdSsXDOn/9ADQoNX9K6YWqulJVzwda4vTS/uZHPMUx5R1nTMXmAdcDy93eVgn30PI24DIgVlWb4pwflOLQj1JnlYehInIDTs8wH/jD8YduagNLdCGiqgdxTro/KSIXiUgDETlBRAaJyENusZeBu0WkuYg0c8v7vJXiKD4GzhGRU0WkCXBH8QIR+ZWIDHPP1RXiHAJ7K6ljOdDJvSXGIyKXA12AZccZEwCqmg30wzknWV4joAjnCq1HRKYBjUst/wZocyxXVkWkE/AXnMPXMcAfRKTKQ2xTu1miCyFVfQS4GecCwz6cw61JwBtukb8AWcAnwKfAJvez41nXW8BCt66NlE1OUTgn6POBAzhJ5/pK6igA0t2yBTg9oXRV3X88MZWre42qVtZbXQm8iXPLyW6cXnDpw9Lim6ELRGSTr/W4pwrmAw+q6mZV3QHcCcwTkfrVaYMJX2IXmowxkc56dMaYiGeJzhgT8SzRGWMiniU6Y0zEq+pm1ZARz4kq9RqFOoyA6550aqhDMMdIfBepdXbvzmH//v0BbVp049NUi370WU5/3LdSVQcGct3+CM9EV68R9TtfFuowAm7tujmhDsEcI5HIS3V9e6cFvE4t+tGvf7M/ffykrydagiIs
E50xprYRCOPRsCzRGWOqT4Co6FBHcVSW6IwxgRHGh/mW6IwxAWCHrsaYusB6dMaYiCZYj84YE+nELkYYY+oAO3Q1xkQ2uxhhjIl0gvXojDF1gPXojDGRzQ5djTGRToBou+pqjIl0do7OGBPZ7NDVGFMXWI/OGBPxrEdnjIloYo+AGWPqAjt0NcZENrsYYYypC8K4Rxe+Kbiazj8zic2L/8iWJfdwy4TzKyw/tWUsy5+ezPqFd7DybzcR36JpybJ7bxrOxlfv4qPX7ubhP4yoybB9WrVyBd2SE0lJ6sishx6osLywsJAxV44iJakj5/Ttw+6cHAAKCgoYeP65NI9txNSbJtVw1P6J1LatWrmCrsmdSU7swMyjtGv0lZeTnNiBs8/sXdIugJkP3k9yYge6JnfmrVUrazDqY1Q8Hp2vKUSCvmYR8YrIxyKyWUQ2iciZwV5nVJQw+/bLGD7pKU6/9C+MHNiDxHanlClz/9SLeSlzPb0uv5/7nnmTGZOHAdCnW1vO6N6OnpfdR4+R99Ij+TTO7tEx2CH7xev1MvWmSbyRsZxNm7eyaOECtm/bVqbMC88/S9PYpmzZvoPJN07h7jtvByAmJoZp02dw34MzQxG6T5HaNq/Xy5Qbb2BJxpt89Mk2Fi14uWK7nnuW2KaxbP1sJ5Nvmspdd94GwPZt21i0cAGbNm9l6bIV3DT5erxebyia4Qep24kO+FFVu6tqN+AO4P5gr7BnSht2fb2fnLwCfi7ysmjlJtL7dy1TJrFdS1Z/+DkA7234gvT+vwZAFerXO4F6J3ioX8+DxxPN3gOHgh2yX7I2rKd9+w60bdeOevXqMeKyy1mWsaRMmcyMpYweMw6Aiy8dwep330FVadiwIWf2PYuYmJhQhO5TpLZtw/qy7Rp5+agK7VqWsYSr3HZdcukIVv/TadeyjCWMvHwU9evXp03btrRv34EN69eHohn+iYr2PYUqtBpeX2Pg22CvpFWLJuR+88tq8r75lvjmTcqU+fSLPC76TXcAhp/bjcYnnUhck4Z8+Ek2/8raQfZb95K96j7e/vd2Ps/+Jtgh+yU/L4/4hISS+fj4BPLz8yop0xoAj8dD4yZNKCgoqNE4j0ekti0/P48EN2Zw2pWXl1exTOuK7crLq/jd8r+TsCLiewqRmrgYcaKIfAzEAC2BcysrJCLXAtcCcMJJ1VqhUPEXquXm73h0MY/eNpLRw3qzdtNO8r75liKvl3atm9G57a/ocOHdAGQ+PZm+H7Rn7aZd1YopEFTLt6Lim+T9KROOIrVt1WpXbWqvhPdV15o8dE0EBgJzpZKtparPqGqaqqaJ58RqrTBv739J+FVsyXz8r2LJ33ewTJk9+w4y6pa/c8YVD3LPExkAHPr+J4YP6Mb6T3P44cfD/PDjYVau3UrvX7etVjyBEp+QQF5ubsl8Xl4uLVu2qqTM1wAUFRVx6OBB4uLiajTO4xGpbYuPTyDXjRmcdrVq1apima8rtis+oeJ3y/9OwkoY9+hqNAWr6gdAM6B5MNeTtXU3HU5tzmmtTuYETzQjL0wlc/UnZcqc3LRhyV/HW6++kBeXrAPg6/98y9k9OhAdHYXHE8XZqR35LPs/wQzXbz3SerJz5w5ysrM5fPgwr76ykCHpw8qUGZw+lPnzXgRg8Wuv0q//ueHbCyglUtuW1rNsuxYtXFChXUPSh/GS267XX3uVfgOcdg1JH8aihQsoLCwkJzubnTt30LNXr1A0wy8i4nMKlRq9j05EEoFoIKgnVrzeI0x98BUynrqB6CjhxSXr2P7lf/jjdUPYtO0rMt/7lHPSOjJj8jBUYc2mnUy5/xUAXn/7I/r17ETWK3eiKG/9ezvL/7UlmOH6zePx8MjsOQwbMhDvES9jx02gS3IyM6ZPI7VHGulDhzF+wkQmjh9LSlJHYmPjmDv/5ZLvJ3Zsy3eHDnH48GEyli4hI3MlSV26hLBFv4jUtnk8Hh597AmGDrkQr9fLuPFXV2zX1RO5
evwYkhM7EBsbx7yXFgDQJTmZS0dexuldu+DxeJj9+JNEh+mYb86Ra/j+0ZHKzg8EdAUiXuDT4lngTlXNrOo7UQ1aaP3OlwU1rlA4sH5OqEMwxyjce4zHo2/vNDZuzApow6Lj2uqJ593js9wPiyZsVNW0QK7bH0Hv0alqeP4JMsYEVDj/UbBHwIwxAWGJzhgT8SzRGWMim7hTmLJEZ4ypNkGIiqrbNwwbY+qAQN1HJyIDReRzEdkpIrdXsvxUEXlXRD4SkU9EZLCvOi3RGWMCIhCJTkSigSeBQUAX4AoRKX9D5N3AK6p6OjAKeMpXvZbojDHVJ35OvvUCdqrql6p6GFgADC9XRnEGCAFoAuT7qtTO0RljAiJAV13jga9LzecCvcuVmQ6sEpHJQEPgPF+VWo/OGFNtxRcjfE1AMxHJKjVdW6Gqiso/vnUF8IKqJgCDgXkiVQ+dYj06Y0xg+Neh2+/jEbBcoHWp+QQqHppOxBkJCVX9QERicAYL2Xu0Sq1HZ4ypPgnYVdcNQEcRaSsi9XAuNiwtV+Yr4DcAIpKEM9blvqoqtR6dMSYgAnGOTlWLRGQSsBJnpKPnVHWriMwAslR1KfB74G8iMhXnsHa8+hidxBKdMSYgAvUImKouB5aX+2xaqZ+3AX2PpU5LdMaYahNCO7CmL5bojDHVF+YDb1qiM8YEhPXojDERzxKdMSbyhW+es0RnjAkM69EZYyKaSHiPR2eJzhgTENajO0bdk05l7brIezVgXL87Qx1CUOx55y+hDiFo6nvCt5dyvIL2gtPwzXPhmeiMMbWP9eiMMZFNLNEZYyKcMx6dJTpjTIQL4w6dJTpjTGDYoasxJrKJ9eiMMRFOwM7RGWMin/XojDGRTaxHZ4yJcIJdjDDGRDwbSt0YUweEcZ6zRGeMCQzr0RljIprYxQhjTF0Qxh06S3TGmMCwQ1djTMQL4zxnic4YEwA2Hp0xJtI5NwyHOoqjs0RnjAkAG3jTGFMH2KGrMSayhfl4dJH3LjfXqpUr6JacSEpSR2Y99ECF5YWFhYy5chQpSR05p28fdufkAFBQUMDA88+leWwjpt40qYaj9u383p3Y/PLNbHnlFm4Z06/C8lNPacryxyeyfu6NrHziGuKbNy5Z9v3797Luhcmse2Eyix4cU5Nh++XtVSvo2a0LqSmdeXTWgxWWFxYWcvWYK0hN6cx555zBV7tzyiz/+uuvSGjehDmzH66hiP0TqftiacUP9fuaQqVGEp2IeEXk41LT7cFcn9frZepNk3gjYzmbNm9l0cIFbN+2rUyZF55/lqaxTdmyfQeTb5zC3Xc6IcXExDBt+gzue3BmMEM8LlFRwuxbhjH8989z+pWPMvK8biS2aVGmzP2TBvPSmx/Ra+zj3Pf8O8y4bmDJsh8Lf6bP+Dn0GT+HkbfNq+nwq+T1erl16o0semMZ6zZ9ymuLFvLZ9rLbbN4Lz9GkaSybtnzOdZOnMP3uO8osv+sPv+e8CwYSTiJ1X6xMnU90wI+q2r3UVPHPWgBlbVhP+/YdaNuuHfXq1WPEZZezLGNJmTKZGUsZPWYcABdfOoLV776DqtKwYUPO7HsWMTExwQzxuPTs0ppduQXk5H/Lz0VeFr29mfSzk8qUSWzTgtVZOwF4b+OXFZaHq41Z62nXvj1t2jrb7JIRl7F82dIyZd7MXMoVo52e6PCLL+W91f9E1Xkdc+bSJZzWti2JSV1qPPaqROq+WJmoKPE5hSy2kK05iPLz8ohPSCiZj49PID8/r5IyrQHweDw0btKEgoKCGo3zWLVq3pjcbw6WzOftO0R88yZlyny6cw8XDUgBYHi/ZBo3jCGucQMAYup5WPPsDbz3zHUMPSe8EsKe/Hzi41uXzLeKT2BPfn6ZMvmlyng8Hho3bsKBggJ++OEHHnvkIW67c1qNxuyPSN0XK3DP0fmaQqWmLkacKCIfl5q/X1UXli4gItcC1wK0PvXUaq2s+K98ufqPuUy4qSy68u2444nl
PHrzMEYP7sHaj7PJ23uQIq8XgE6XPMie/d/RplUsK+Zcw5Zd/yE770ANRO6bX9vjKGUe+Mt0rps8hZNOOilI0R2/SN0XyxMbjw5wD12rKqCqzwDPAKT2SKu45Y9BfEICebm5JfN5ebm0bNmqkjJfk5CQQFFREYcOHiQuLq46qw26vH2HSPjVLz24+OaNyd9/qEyZPfu/Y9SdLwHQ8MR6XNQ/hUM/FJYsA8jJ/5Z/bfqS7p1ahU2iaxUfT17e1yXz+Xm5nNKyZaVl4ou32aGDxMbFkbVhPUsWv849d93OwYP/JSoqivr1Y7j2uhtquhkVROq+WJlA5TkRGQg8BkQDf6/sVJeIXAZMBxTYrKpXVlVnRB669kjryc6dO8jJzubw4cO8+spChqQPK1NmcPpQ5s97EYDFr71Kv/7nhvVfJICs7bl0SGjGaS1jOcETzcjzupG5ZnuZMic3aVDSjlvH9ufFZVkANG0UQ70TokvKnNH1NLZn763ZBlQhtUdPdu3cye4cZ5u9/uorDBoytEyZgYOH8vJ85yLKksWvcU6/AYgIb779Hp98totPPtvFdTfcyM233h4WSQ4id1+sTJSIz8kXEYkGngQGAV2AK0SkS7kyHYE7gL6qmgxM8VVvRN5H5/F4eGT2HIYNGYj3iJex4ybQJTmZGdOnkdojjfShwxg/YSITx48lJakjsbFxzJ3/csn3Ezu25btDhzh8+DAZS5eQkbmSpC6hP6fl9R5h6iNLyXj0aqKjhReXZbE9ey9//O15bPosj8w12zkntR0zfnchqrDm42ymPOyc+E48rQVzbruYI0eUqChh1rz3+CwnfBKdx+PhoUce49Jhg/F6vVw1djxJXZK5b8Y9dE9NY3D6UMaMv5rfTRxHakpnYmNjeXbuP0Idtk+Rui9WJkC5uRewU1W/dOqUBcBwoPSl6muAJ1X1WwBV9bkjS2XnB9wVNK50gUtVD1W1vFxdXuDTUh+tUNWj3mKS2iNN167b4G/1tUZcvztDHUJQ7HnnL6EOIWjqeyLvoKdvn55s2pgV0C5jk9OS9MzbX/BZbsX1fXYD+0t99Ix72goAERkBDFTV37rzY4DeqjqpVJk3gC+AvjiHt9NVdUVV662qR7cV5/i39C+keF4Bv68YqGq0v2WNMbWTn4fb+1U1rapqKvmsfG/MA3QE+gMJwPsikqKq/z1apUdNdKra+mjLjDGmvAAduuYCpXNPApBfSZl1qvozkC0in+MkvqMeBvrVLxeRUSJyp/tzgoj0OJbIjTGRTXBvMfHxnx82AB1FpK2I1ANGAUvLlXkDGAAgIs2ATsCXVVXqM9GJyBNupcUPR/4PeNqfiI0xdUeU+J58UdUiYBKwEtgOvKKqW0VkhogUX65eCRSIyDbgXeBWVa3yDmt/rrqeqaqpIvKRG8gBN9MaY4xDAveIl6ouB5aX+2xaqZ8VuNmd/OJPovtZRKJwTwiKyMnAEX9XYIyJfAJ+3ScXKv6co3sSeA1oLiJ/AtYAFcfQMcbUabX6WVdVnSsiG4Hz3I9GquqW4IZljKltwvlpDn+fjIgGfsY5fI28OyiNMdUS6h6bL/5cdb0LeBlohXNPyz9E5I6qv2WMqWuiRXxOoeJPj2400ENV/wcgIvcCG4H7gxmYMaZ2qe2HrrvLlfPg4+Y8Y0zd4lx1DXUUR3fURCcij+Kck/sfsFVEVrrzF+BceTXGGEeI3wnhS1U9uuIrq1uBzFKfrwteOMaY2iqM81yVD/U/W5OBGGNqt9raowNARNoD9+KM9lnyOiJV7RTEuIwxtYgA0WF8ks6fe+JeAJ7Hacsg4BVgQRBjMsbUQuLHFCr+JLoGqroSQFV3qerduEOkGGMMOOfnAvHOiGDx5/aSQnEOvneJyO+APKCFj+8YY+qYMD5F51eimwqcBNyIc66uCXB1MIMyxtQ+tfpihKp+6P74Hb8MvmmMMSUECeuLEVXdMLyYii+lKKGqlwQlImNM7RPmD/VX1aN7osai
qCP+E6GvBTxlQOSO8fDtmodCHULABSsf1cpDV1V9pyYDMcbUbuE8fpu/49EZY8xRCbW0R2eMMccijK9F+J/oRKS+qhYGMxhjTO0kUssfARORXiLyKbDDne8mInOCHpkxplYJxHtdgxabH2UeB9KBAgBV3Yw9AmaMKadWvwUMiFLV3eVONHqDFI8xphYK9/e6+pPovhaRXoCKSDQwGfgiuGEZY2qb2n57yXU4h6+nAt8Ab7ufGWMM4NxaEs4XI/x51nUvMKoGYjHG1GJhfOTq1wjDf6OSZ15V9dqgRGSMqZXCuEPn16Hr26V+jgEuBr4OTjjGmNqo1l+MUNWFpedFZB7wVtAiMsbUSmGc547rEbC2wGmBDsQYU4uF+IZgX/w5R/ctv5yjiwIOALcHMyhjTO0iQHQYd+mqTHTuuyK64bwnAuCIqh51ME5jTN0Vzj26Ku/xc5PaYlX1upMlOWNMpUTE5xQq/tzMvF5EUoMeiTGm1nKuugbmoX4RGSgin4vIThE56mkyERkhIioiab7qrOqdER5VLQLOAq4RkV3AD26bVFUt+RljHAF6aN99zPRJ4HwgF9ggIktVdVu5co1w3kz4YcVaKqrqHN16IBW46LgiNsbUGQJ4AnOSrhewU1W/BBCRBcBwYFu5cn8GHgJu8afSqhKdAKjqrmMO1RhT5/jZo2smIlml5p9R1WdKzcdT9oGEXKB32fXI6UBrVV0mItVOdM1F5OajLVTVR/xZgTGmLhCi/Hu/2H5VreqcWmWVlFwEFZEo4FFg/LFEV9XFiGjgJKDRUaawtmrlCrolJ5KS1JFZDz1QYXlhYSFjrhxFSlJHzunbh905OQAUFBQw8PxzaR7biKk3TarhqP3z9qoVpHXrwukpnXl01oMVlhcWFjJhzBWcntKZ35xzBrt35wCwccN6zurdg7N696Bv71QylrxRw5Ef3fl9OrF54a1sWfQHbhnTv8LyU09pyvI517B+/lRWPvV/xDdvUrLs+7UPsG7uFNbNncKimeNrLmg/rVq5gq7JnUlO7MDMo+yLo6+8nOTEDpx9Zu+SfRFg5oP3k5zYga7JnXlr1coajPrYOC/HCcjAm7lA61LzCUB+qflGQAqwWkRygD7AUl8XJKrq0e1R1Rl+hVYFETkZKH514ik4g3buc+d7qerh6q6jPK/Xy9SbJrFs+SriExI4+4xeDEkfRlKXLiVlXnj+WZrGNmXL9h0sWriAu++8nXn/WEBMTAzTps9g69YtbNu6JdChVZvX6+WWqTfyxrIVtIpPYMDZfRg0ZCiJSb+0bd4Lz9G0aSwfbfmc1xYtZPrdd/D8vJdJSk5h9doP8Xg8/GfPHs7qk8qgIel4PKF9R1JUlDD7losZcuPfyNt7kDXPT2bZ+9v4LGdvSZn7J6fz0pubeGn5Rvr1aM+M6wcy8U/O04k/Fv5Mn7GzQxV+lbxeL1NuvIHMN98iPiGBs/r0JL38vvjcs8Q2jWXrZzt5ZeEC7rrzNub/YyHbt21j0cIFbNq8lT35+QweeB6fbvuC6OjoELboKAL3ZMQGoKOItMW5f3cUcGXxQlU9CDQrWa3IauAWVc2iClX16AIStqoWqGp3Ve0OPA08WjwfjCQHkLVhPe3bd6Btu3bUq1ePEZddzrKMJWXKZGYsZfSYcQBcfOkIVr/7DqpKw4YNObPvWcTExAQjtGrbmLWedu3b06at07ZLR1zG8mVLy5RZnrmUK0aPAWD4xZfy3up/oqo0aNCgJKn9VPhT2LyermeX1uzK3U9O/gF+LvKy6K3NpJ+TXKZMYtsWrN6wE4D3Nu6qsDxcbVhfdl8cefmoCvvisowlXOXui5dcOoLV/3T2xWUZSxh5+Sjq169Pm7Ztad++AxvWrw9FM/wSJeJz8sW902MSsBLYDryiqltFZIaIDDvu2KpY9pvjrTTU8vPyiE9IKJmPj08gPz+vkjJOD9nj8dC4SRMKCgpqNM7jsSc/n/j4X3r2reIT2JOff9QyHo+H
xo2bcMBtW9b6D+nToyt9e3bnkceeCnlvDqBV8ybk7j1YMp+39yDxzRuXKfPpjj1cNCAFgOH9U2jcMIa4xg0AiKnnYc3zN/Le329gaJglwPz8PBISftle8fEJ5OXlVSzTuuK+mJdX8bvl9+NwIThvAfM1+UNVl6tqJ1Vtr6r3up9NU9WllZTt76s3B1UcuqrqAb+iCkOVPcBRvvfiT5lwVOnDKcfQtrRevVm38RM+/2w7110zgfMvHBjy3mtlv/byLbhjTiaP3jKc0UPSWPtxNnl7/0uR9wgAnS66nz37D9GmVRwrnryWLbv2kJ0XHrtvtfbFWraPhnFo4TPMu4hcKyJZIpK1f/8+31+oQnxCAnm5uSXzeXm5tGzZqpIyzlXsoqIiDh08SFxcXLXWWxNaxceTl/fL1ff8vFxatmx51DJFRUUcOnSQ2HJt65yYRIOGDdkeBuch8/YeJKHFLxcX4ls0IX/foTJl9uw/xKjb53HGuMe45+kVABz64aeSZQA5+Qf416Yv6d4pvoYi9y0+PoHc3F+2V15eLq1atapY5uuK+2J8QsXvlt+Pw4XgJBNfU6iETaJT1WdUNU1V05o1a16tunqk9WTnzh3kZGdz+PBhXn1lIUPSyx7eD04fyvx5LwKw+LVX6df/3LD+a1kstUdPdu3cSU6O07bXXn2FQUOGlikzaPBQXp4/D4Ali1/jnH4DEBFycrIpKioC4KuvdrPziy849bQ2Nd2ECrK259KhdTNOaxnLCZ5oRp7fjcz3y94fenKTBiXb59ZxA3gxwzlaadroROqdEF1S5oyubdie/U3NNqAKaT3L7ouLFi6osC8OSR/GS+6++Pprr9JvgLMvDkkfxqKFCygsLCQnO5udO3fQs1evUDTDNwnvZ11Df4ImCDweD4/MnsOwIQPxHvEydtwEuiQnM2P6NFJ7pJE+dBjjJ0xk4vixpCR1JDY2jrnzXy75fmLHtnx36BCHDx8mY+kSMjJXlrlKFkoej4eZjzzGpcMG4/V6GT12PEldkrl3xj2cnprG4PShjBl/Nf83cRynp3QmNjaW5+b+A4B1/17L7IcfwuM5gaioKGbNfoKTmzXzscbg83qPMHXWEjIe+y3RUVG8uGwD27O/4Y/XXMCmz3KslFYXAAARSUlEQVTJfH8b56S2Z8b1g1BV1nyczZSZiwFIbNOCObddwhFVokSYNffdMldrQ83j8fDoY08wdMiFeL1exo2/uuK+ePVErh4/huTEDsTGxjHvpQUAdElO5tKRl3F61y54PB5mP/5keF5xdYVzN0FqckASEZkOfK+qs6oql9ojTdeu21AzQdWgw0VHQh1CUJwy4I5QhxA03655KNQhBFzf3mls3JgV0LzUrktX/fO85T7LjU5rvdHHDcNBUaM9OlWdXpPrM8bUnHA+8xORh67GmJoW2nNwvliiM8ZUW/FV13Blic4YExDWozPGRDap5e91NcYYX+zQ1RhTJ9ihqzEm4oVvmrNEZ4wJkDDu0FmiM8ZUn3OOLnwznSU6Y0wA+DewZqhYojPGBEQY5zlLdMaY6rNDV2NM5PP/LV8hYYnOGBMQluiMMRFNgOgwznSW6IwxASF2js4YE+nCuENnic4YExjWozPGRDQB/Hw/dUhYojPGBIBYj84YE+HEenTGmAjnHLqGb6YL20RXg6+brTHhvCNUx4H3Hwx1CEETe9YfQh1CwBV+nhuUesN57w7bRGeMqWXCONNZojPGBIRdjDDGRDy7GGGMiXyW6IwxkUywQ1djTKQL8/Howvmds8aYWkT8mPyqR2SgiHwuIjtF5PZKlt8sIttE5BMReUdETvNVpyU6Y0xgBCDTiUg08CQwCOgCXCEiXcoV+whIU9WuwKvAQ77qtURnjAkA5y1gviY/9AJ2quqXqnoYWAAML11AVd9V1f+5s+uABF+VWqIzxlSbP505N801E5GsUtO15aqKB74uNZ/rfnY0E4E3fcVnFyOMMYHh30m4/aqadoy1VPpA
qIiMBtKAfr5WaonOGBMQAbq9JBdoXWo+AcivsC6R84C7gH6qWuirUjt0NcYEhIjvyQ8bgI4i0lZE6gGjgKVl1yOnA38FhqnqXn8qtR6dMab6AnQfnaoWicgkYCUQDTynqltFZAaQpapLgZnAScAicVb6laoOq6peS3TGmIAI1JMRqrocWF7us2mlfj7vWOu0RGeMqTYhvJ+MsERnjAmIMM5zluiMMQESxpnOEp0xJiDC+VUBluiMMQERvmnOEp0xJlDCONNF7A3Dq1auoHtKIr9O6sismQ9UWF5YWMjYq0bx66SO9DurD7tzcgAoKChg0AXn0iKuETffNKmGo/bPW6tWkNo1iW7JnXhkZsU3cBUWFjJ+9Ci6JXdiwNlnsHt3DgBZG9bTt3cqfXuncmav08lYsriGI6/aqpUr6JacSEpSR2Y9VPk2G3PlKFKSOnJO37LbbOD559I8thFTw3Cbnd+nE5sX3sqWRX/gljH9Kyw/9ZSmLJ9zDevnT2XlU/9HfPMmJcu+X/sA6+ZOYd3cKSyaOb7mgj5GxQNv+vovVIKW6ERERWReqXmPiOwTkWXBWmcxr9fLzTdNYvHS5WzcvJVFCxewffu2MmVefP5ZmjZtyqfbdzDpxin88S5n2KuYmBj+eM8M7ntgZrDDPC5er5ffT5nMa0sy2fDRFl5dtIDPyrVt7gvP0TQ2ls1bv+CGyTdxj9u2LskpvLd2PWs/3MTrS5Zz0+TrKCoqCkUzKvB6vUy9aRJvZCxnU/E221a2XS88/yxNY5uyZfsOJt84hbvv/GWbTZs+g/seDL9tFhUlzL7lYoZPfZbTr3iYkRd0J7FNizJl7p+czktvbqLX6Ee579m3mXH9wJJlPxb+TJ+xs+kzdjYjb32hhqM/Bn48FRHKU3jB7NH9AKSIyInu/PlAXhDXVyJrw3rate9A23btqFevHiMuu5xlGUvKlFmWsZSrxowD4OJLRrD63XdQVRo2bMiZfc+ifkxMTYR6zJy2tadtW6dtl468nMxlZZ6QIXPZEq64aiwAF10ygtWr/4mq0qBBAzwe52zFT4U/IWF08jhrw3ra+9hmmRlLGV28zS6tuM1iwnCb9ezSml25+8nJP8DPRV4WvbWZ9HOSy5RJbNuC1Rt2AvDexl0VltcWgRp4MxiCfej6JjDE/fkK4OUgrw+A/Pw8Elr/MkRVfHwCe/LyKpZJcJ4d9ng8NG7chIKCgpoIr1r2lIoboFV8PPnl2rYnP79C2w64bduw/kN6pf6aM9K6Mfvxp0oSX6jl5+URn1B2m+Xn51VSplS7moT/NmvVvAm5ew+WzOftPUh888Zlyny6Yw8XDUgBYHj/FBo3jCGucQMAYup5WPP8jbz39xsYGtYJUBDxPYVKsBPdAmCUiMQAXYEPg7w+AFQrjupS4ZfsT5kw5E/bKitTfNzQs1dv1m/6lNVrPuThmQ/y008/BSXOY3W87Qr3bVZZeOVbccecTM5ObccHL97E2ae3I2/vfynyHgGg00X3c9aExxk37WVmTh1K2/i44Ad9nOrqoSuq+gnQBqc3t7yqsiJybfFgfPv376vWeuPjE8j9OrdkPi8vl1NatSpTplV8Arm5zvh+RUVFHDp0kLi48N2JipWOG5xeTssKbYv32bbOiUk0bNiQbVu3BD9oP8QnJJCXW3abtWzZqpIypdp1MPy3Wd7egyS0+OXiQnyLJuTvO1SmzJ79hxh1+zzOGPcY9zy9AoBDP/xUsgwgJ/8A/9r0Jd07VTUGZegcw8CbIVETV12XArPwcdiqqs+oapqqpjVr1rxaK+yR1pNdO3eQk53N4cOHefWVhQxJLzu4wZD0obw070UAFr/+Kv36nxv2vQNw2vblzp3k5Dhte23RQgYPGVqmzOAhw3j5pbkAvPH6q/TrNwARIScnu+Tiw1e7d7Pji8857bQ2Nd2ESvVI68lOH9tscPpQ5hdvs9dqxzbL2p5Lh9bNOK1lLCd4ohl5fjcy3y97keXkJg1K
2nHruAG8mJEFQNNGJ1LvhOiSMmd0bcP27G9qtgHHIowzXU2coHkOOKiqn4pI/xpYHx6Ph4dnz2F4+kC8Xi9jx0+gS5dk/vynaaSmpjFk6DDGTZjIbyeM5ddJHYmNi+PFeb/k4aRObfnu0CEOHz5MRsYSlmauJCmp/Ps5QsPj8TDz0ce5eOggvF4vY8ZNIKlLMn+ZcQ+pqT0YnD6MseOv5tqrx9ItuROxsXE8P+8fAHzw7zU8OushTjjhBKKionjksSc4uVmzELfI4fF4eGT2HIYNGYj3iJex4ybQJTmZGdOnkdojjfShwxg/YSITx48lJakjsbFxzJ3/yzZL7Fhqmy1dQkbmSpK6hH6beb1HmDprCRmP/ZboqCheXLaB7dnf8MdrLmDTZ7lkvr+Nc1LbM+P6Qagqaz7OZspM57afxDYtmHPbJRxRJUqEWXPf5bMcv4ZfC4lwfq+rVHo+JxAVi3yvqieV+6w/cIuqplf13dQeabrmgw1BiSuUvEeC87sONU90+O7g1RV39m2hDiHgCrfM5cj3/wnoRuvavYdm/vPfPsudenLMRh9DqQdF0Hp05ZOc+9lqYHWw1mmMCZEwf4F1eNxbYIyJAOGb6SzRGWOqzQbeNMbUCWGc5yzRGWMCw3p0xpiIF873NFqiM8YERPimOUt0xpgACPWzrL5YojPGBEQ4Pxlhic4YExjhm+cs0RljAiPKEp0xJrKF9p0QvliiM8ZUW7g/GRGxbwEzxphi1qMzxgREOPfoLNEZYwLCztEZYyKaiF11NcbUBZbojDGRzg5djTERzy5GGGMiXhjnOUt0xpjAsPHojDERLdyfjAjae12rQ0T2AbtraHXNgP01tK6aZO2qXWqyXaepavNAVigiK3Da4Mt+VR0YyHX7IywTXU0SkaxQvFA32KxdtUuktitc2LOuxpiIZ4nOGBPxLNHBM6EOIEisXbVLpLYrLNT5c3TGmMhnPTpTa4jIqaGOwdROluhMrSAig4F3RCQ+1LGY2scSnQl7InIhMAsYo6p5IhJR+62IxIY6hkgXUTvMsRCRkyNxBxORHiLSJ9RxBIqIXADMBbYBBwBU9YiE8/NGx8Bt31vu/02Q1MlE5x4GvQn8VUT+Eup4AkVEBuFcvftfqGMJBBH5DfAEcDPwb+BqETkLQFU1QpJdZyAFuEVELgp1MJGqzj3rKiIDgTuBe3EeM7tZRE5U1R9DG1n1uO26G7hDVT9xe6tNVTU7xKFVxyFgvKr+W0Q6A6OBISKiqrq2ONlp7b514GWgHfAVMFZETlDVRSGOKeLUqR6diMQBy4GHVXUJUA84H5glIn8tVa5W9RRKtWumqq4SkfbAUqBWX6VU1Q1ukotS1c9xDmF/BtJF5Ey3TK1LciLSVUS6urMHgMNAMvD/gNEicmnIgotQdSrRqeoBYCgwTUS64fTqngEeALqJyMtuuVr1j6dcu7oCTwNvqOp7oY0sMFT1iPv/HcA84CdglIj0Dmlgx0FETgY+BpaJyAigB3AXUIjz7/EfOD27K0IXZeSpU4kOQFUzgTuAj4B3VPUeVf0aOA9o7u6ItY7brjtx/hG9o6oPi0g0OIe1IjIgpAEGiJvsFgJ7gC9DHM4xU9UCnH0tAegKDMTpqf4PaK6qC4HFwHARaRSyQCNMnX0yQkTOxznR3VtV/ysiE4BrgAtV9bvQRnf83HbNAfq47RoPXA9cXsvP15Xhnsv6OdRxHC/3QstzQCowArgSyAUmAPUBavN+GG7qbKKDkquUM4GngFHA9aq6JbRRVZ/brodw2nUl8DtV3RraqEx57tX/B4EzVPV7EWkbSX+Mwkmdu+pamqq+6R7evQ6cHinJIFLbFWlUdbl73WuDiPQtTnIRcCU57NTpHl0xEWmgqhFx71lpkdquSCMiw4F7gDSca2H2jzLALNEZEwZE5CRV/T7UcUQqS3TGmIhX524vMcbUPZbojDERzxKdMSbiWaIzxkQ8S3S1kIh4ReRjEdkiIotEpEE1
6uovIsvcn4eJyO1VlG0qItcfxzqmi8gt/n5erswL7jOh/q6rjYjU+pu+TWBZoqudflTV7qqagjPyxe9KLxTHMW9bVV2qqg9UUaQpzuNkxtQqluhqv/eBDm5PZruIPAVsAlqLyAUi8oGIbHJ7fidByUP+n4nIGuCS4opEZLyIPOH+/CsRWSwim93pTJxRXtq7vcmZbrlbRWSDiHwiIn8qVdddIvK5iLyNM7hklUTkGreezSLyWrle6nki8r6IfCEi6W75aBGZWWrd/1fdX6SJXJboajER8QCDgE/djzoDc1X1dOAHnIE4z1PVVCALZ5DRGOBvOMM6nQ2ccpTqHwfeU9VuOA+ebwVuB3a5vclb3eG/OwK9gO5ADxE5R0R64Dw7fDpOIu3pR3NeV9We7vq2AxNLLWsD9AOGAE+7bZgIHFTVnm7914hIWz/WY+qgOv2say12ooh87P78PvAs0ArYrarr3M/7AF2Ate7zlPWAD4BEINsd7ggRmQ9cW8k6zgXGAqiqFzgoFd+xcYE7feTOn4ST+BoBi4sfPxORpX60KUWcYe2buvWsLLXsFXdMuh0i8qXbhguArqXO3zVx1/2FH+sydYwlutrpR1XtXvoDN5n9UPoj4C1VvaJcue5AoB6HEeB+Vf1rmQ9FphzHOl4ALlLVze7QUv1LLStfl7rrnqyqpRMiItLmGNdr6gA7dI1c64C+ItIBnAf8RaQT8BnQ1h1uHeBoI9m+A1znfjdaRBoD3+H01oqtxHlhTfG5v3gRaQH8C7hYRE50B48c6ke8jYA9InICcFW5ZSNFJMqNuR3wubvu69zyiEgnEWnox3pMHWQ9ugilqvvcntHLIlLf/fhuVf1CRK4FMkVkP7AG5y1U5d0EPCMiEwEvcJ2qfiAia93bN950z9MlAR+4PcrvgdGquklEFuKMdrwb5/Dalz8CH7rlP6VsQv0ceA/4Fc7Yej+JyN9xzt1tEmfl+wB7i5aplD3Ub4yJeHboaoyJeJbojDERzxKdMSbiWaIzxkQ8S3TGmIhnic4YE/Es0RljIt7/B8stmLIEKNNVAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot normalized confusion matrix\n",
    "# y_test.argmax(axis=1) collapses per-row class scores (presumably one-hot\n",
    "# test labels) into integer class indices -- TODO confirm y_test encoding\n",
    "# upstream where it is built.\n",
    "# NOTE(review): the second argument is named `confusion_matrix_predicted`;\n",
    "# verify against the plot_confusion_matrix definition whether it expects\n",
    "# predicted labels or a precomputed matrix.\n",
    "plot_confusion_matrix(y_test.argmax(axis=1), confusion_matrix_predicted, classes=class_names, normalize=True,\n",
    "                      title='Confusion Matrix')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
